syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/core/protobuf/tpu/optimization_parameters.proto";
import "tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto";

message TPUEmbeddingConfiguration {
  // Description of the various embedding tables.
  message TableDescriptor {
    // Name of the table.
    string name = 1;
    // Size of the vocabulary (i.e., number of rows) in the table.
    int32 vocabulary_size = 2;
    // The embedding dimension (i.e., the width of the embedding table).
    int32 dimension = 3;
    // Number of features mapped to this table.
    int32 num_features = 4;
    // Details of the learning algorithm used to update the embedding
    // parameters.
    OptimizationParameters optimization_parameters = 5;
  }
  repeated TableDescriptor table_descriptor = 1;
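  // As an illustration (all values here are hypothetical, not defaults), a
  // table of 128-dimensional embeddings over a one-million-id vocabulary
  // that backs two features could be written in textproto as:
  //
  //   table_descriptor {
  //     name: "word_embedding_table"
  //     vocabulary_size: 1000000
  //     dimension: 128
  //     num_features: 2
  //     optimization_parameters { ... }  # optimizer settings elided
  //   }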

  // Mode in which the embedding layer program is run: inference (forward
  // pass only), training (both forward and backward passes), or the backward
  // pass only.
  enum Mode {
    UNSPECIFIED = 0;
    INFERENCE = 1;
    TRAINING = 2;
    BACKWARD_PASS_ONLY = 3;
  }
  Mode mode = 2;

  // Number of samples in each batch of embedding layer activations sent to
  // the TensorCore.
  int32 batch_size_per_tensor_core = 3;

  // Number of TPU hosts used for inference/training.
  int32 num_hosts = 4;

  // Number of TensorCores used for inference/training.
  int32 num_tensor_cores = 5;

  // Sharding strategy of the embedding tables among the hosts.
  // If the sharding_strategy is "mod", each id is assigned to host
  // "id % num_hosts". For instance, 13 ids are split across 5 hosts as:
  // [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
  // If the sharding_strategy is "div", ids are assigned to hosts in a
  // contiguous manner. In this case, 13 ids are split across 5 hosts as:
  // [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
  // Under both strategies, if the number of hosts does not evenly divide the
  // id space, each of the first "table_descriptor.vocabulary_size % num_hosts"
  // hosts is assigned one extra id.
  // This partitioning strategy exactly matches the one used by the
  // embedding_lookup TensorFlow function at
  // tensorflow/python/ops/embedding_ops.py. A sketch of the resulting host
  // computation follows the field below.
  enum ShardingStrategy {
    DIV_DEFAULT = 0;
    MOD = 1;
  }
  ShardingStrategy sharding_strategy = 6;
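  // As a sketch, the host index for a given id under each strategy can be
  // computed with integer arithmetic as follows (an illustrative restatement
  // of the comment above, not a normative definition):
  //
  //   small  = vocabulary_size / num_hosts   (rounded down)
  //   extras = vocabulary_size % num_hosts
  //   MOD: host = id % num_hosts
  //   DIV: host = id < extras * (small + 1)
  //                   ? id / (small + 1)
  //                   : extras + (id - extras * (small + 1)) / small
  //
  // E.g., with vocabulary_size=13 and num_hosts=5: small=2, extras=3, so
  // under DIV, id 11 maps to host 3 + (11 - 9) / 2 = 4, matching the example
  // above.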

  // This parameter determines whether the execution of the sparse core is
  // pipelined with that of the TensorCore. It only affects results when
  // mode=TRAINING; if mode=INFERENCE or BACKWARD_PASS_ONLY, it does not
  // affect execution and is hence a don't-care value.
  //
  // false: The execution of the sparse core is not pipelined with that of the
  // TensorCore. The forward pass of every step on the sparse core is executed
  // only after the backward pass of the previous step is complete, and the
  // backward pass on the sparse core is executed only after the embedding
  // gradients have been computed on the TensorCore for that step. This
  // ensures that the activations on every step observe the gradient updates
  // from the previous step on both the sparse core and the TensorCore.
  //
  // true: The execution of the sparse core is pipelined with that of the
  // TensorCore. The forward pass of every step on the sparse core can be
  // executed after the forward pass of the previous step is complete, without
  // waiting for the backward pass. This improves the utilization of the
  // sparse core, allowing it to process step N+1 while the embedding
  // gradients for step N are computed on the TensorCore. The backward pass of
  // every step on the sparse core is executed directly after the forward pass
  // for the next step is complete. The drawback is that the embedding
  // activations for step N+1 do not observe the embedding gradient updates
  // from step N. This could affect model quality if steps N and N+1 involve
  // the same set of embedding IDs. However, since embedding updates are
  // sparse, this is generally not considered a problem. A rough timeline of
  // both settings follows the field below.
  bool pipeline_execution_with_tensor_core = 7;
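  // A rough timeline sketch of the two settings, where fwd/bwd denote the
  // sparse core passes and grad denotes the TensorCore gradient computation
  // (illustrative only; the actual schedule is an implementation detail):
  //
  //   false: fwd(N) -> grad(N) -> bwd(N) -> fwd(N+1) -> grad(N+1) -> ...
  //   true:  fwd(N) -> fwd(N+1) -> bwd(N) -> fwd(N+2) -> bwd(N+1) -> ...
  //          (grad(N) runs on the TensorCore concurrently with fwd(N+1).)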

  // Extended output layout information; if not provided, a compatibility mode
  // will use defaults that match the old layout. Providing a value for this
  // field is EXPERIMENTAL and most ways of filling it will probably break. Do
  // not set it unless you know what you are doing.
  TPUEmbeddingOutputLayout output_layout = 8;
}
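
// As an end-to-end illustration (all values hypothetical, chosen only for
// this sketch), a minimal single-table training configuration could be
// written in textproto as:
//
//   table_descriptor {
//     name: "video_id_table"
//     vocabulary_size: 100000
//     dimension: 64
//     num_features: 1
//   }
//   mode: TRAINING
//   batch_size_per_tensor_core: 128
//   num_hosts: 4
//   num_tensor_cores: 32
//   sharding_strategy: DIV_DEFAULT
//   pipeline_execution_with_tensor_core: false
//
// optimization_parameters (per table) and output_layout are omitted; see the
// field comments above.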
